{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch \n",
    "import numpy as np\n",
    "import pandas as pd \n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from LipNet import LipNet\n",
    "from LipReadDataTest import ReadData as ReadDataTest"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "===============================================\n",
    "            1. Predict\n",
    "==============================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run batched inference over the test set: build the data loader, restore the\n",
    "# saved LipNet weights, and collect one predicted class index per sample.\n",
    "test_image_file = os.path.join(os.path.abspath('.'), \"data/lip_test\")\n",
    "test_dataset = ReadDataTest(test_image_file, seq_max_lens=24)\n",
    "# Inference does not need shuffling; a fixed order keeps the output reproducible.\n",
    "test_data_loader = DataLoader(test_dataset, batch_size=20, shuffle=False, num_workers=8, drop_last=False)\n",
    "\n",
    "# Use the first GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "model = LipNet().to(device)\n",
    "# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.\n",
    "model.load_state_dict(torch.load(\"./weight/demo_net_epoch_2.pt\", map_location=device))\n",
    "model.eval()\n",
    "\n",
    "col_key = []  # file name (last path component) of every test sample\n",
    "col_pre = []  # predicted class index per sample, aligned with col_key\n",
    "with torch.no_grad():\n",
    "    for sample_batched in test_data_loader:\n",
    "        input_data = sample_batched['volume'].to(device)\n",
    "        # Per-sample number of leading output steps to include in the sum.\n",
    "        # Kept as plain Python ints: they are only used as slice bounds.\n",
    "        lengths = [int(n) for n in sample_batched['length']]\n",
    "\n",
    "        # os.path.basename honours the separator(s) of the current platform.\n",
    "        keys = [os.path.basename(key) for key in sample_batched['key']]\n",
    "\n",
    "        outputs = model(input_data)\n",
    "        # For each sample, sum the model output over its first `n` steps only,\n",
    "        # then pick the index with the highest summed score.\n",
    "        summed_scores = torch.stack([outputs[i, :n].sum(0) for i, n in enumerate(lengths)])\n",
    "        _, max_indexs = torch.max(summed_scores, 1)\n",
    "\n",
    "        col_key += keys\n",
    "        col_pre += max_indexs.cpu().tolist()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "===============================================\n",
    "            2. Write the submission file\n",
    "==============================================="
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate every predicted class index into its dictionary word and write the\n",
    "# (file name, word) pairs to a CSV file without header or index column.\n",
    "vocab_frame = pd.read_csv('./dictionary/dictionary.csv', encoding='utf8')\n",
    "index_to_word = vocab_frame['dict'].tolist()\n",
    "character_label = list(map(index_to_word.__getitem__, col_pre))\n",
    "predict = pd.DataFrame({'key': col_key, 'word': character_label})\n",
    "predict.to_csv('预测结果.csv', encoding='utf8', index=False, header=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
